{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 基于模式的使用规则进行重写\n",
    "\n",
    "参考：[rewrite_patterns](https://onnxscript.ai/tutorial/rewriter/rewrite_patterns.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ONNX 重写工具为用户提供了一个功能，可以根据用户提供的重写规则，将 ONNX 计算图中的某些模式替换为另一种模式。\n",
    "\n",
    "## 使用方法\n",
    "\n",
    "在计算图重写模式时，需要三个主要部分：\n",
    "\n",
    "- `target_pattern`：要匹配的原始模式。这个模式使用类似 ONNXScript 的算子编写函数。\n",
    "- `replacement_pattern`：用于替换原始模式的模式。这个模式也使用类似 ONNXScript 的算子编写函数。\n",
    "- `match_condition`（可选）：只有满足匹配条件时，才会进行模式重写。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 简单的例子\n",
    "\n",
    "一个简单示例，演示了如何使用 GELU 激活函数的此功能：\n",
    "\n",
    "可以使用给定公式中的高斯误差函数来计算 GELU 激活函数："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "\\text{GELU} = x\\Phi(x) = x \\cdot \\frac{1}{2} [1 + \\text{erf}(x / \\sqrt{2})]\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxscript.rewriter import pattern\n",
    "from onnxscript import ir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用 `onnxscript` 算子创建需要替换的目标模式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "def erf_gelu_pattern(op, x):\n",
    "    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "之后，创建由 GELU `onnxscript` 算子组成的替换模式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gelu(op, x: ir.Value):\n",
    "    return op.Gelu(x, _domain=\"com.microsoft\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "替换模式的输入是 `ir.Value` 类型。\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这个例子中，我们不需要 `match_condition`，所以暂时跳过这个选项。然后使用 `RewriteRule` 函数创建重写规则。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "rule = pattern.RewriteRule(\n",
    "    erf_gelu_pattern,  # Target Pattern\n",
    "    gelu,  # Replacement Pattern\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在重写规则已经创建，下一步是应用这些基于模式的重写规则。`rewriter.rewrite` 调用包含三个主要部分：\n",
    "\n",
    "- `model`：要应用模式重写规则的原始模型。这是 `onnx.ModelProto` 类型。\n",
    "- `function_rewrite_rules`：（可选）此参数用于传递基于函数名称的重写规则。如何使用此参数的步骤将在另一个教程中介绍。此参数是 `Sequence[type[FunctionRewriteRule]]` 类型。\n",
    "- `pattern_rewrite_rules`：（可选）此参数用于传递基于提供的替换模式的重写规则。在本教程中，我们将仅使用此参数与model结合。此参数是以下类型之一：\n",
    "    - `Sequence[PatternRewriteRule]`\n",
    "    - `RewriteRuleSet`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "`pattern_rewrite_rules` 接受 `PatternRewriteRule` 类型的序列，或者 `RewriteRuleSet`，后者本质上也是使用 `PatternRewriteRule` 类型的序列创建的规则集。因此，如果要传递单个重写规则，需要将其作为序列的一部分传递。有关如何创建和使用规则集的步骤，请参阅“[使用不同模式创建规则集](https://onnxscript.ai/tutorial/rewriter/rewrite_patterns.html#heading-target-commute-ruleset)”部分中的示例。\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面的代码片段演示了如何使用 `rewriter.rewrite` 调用上述创建的重写规则："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_rewrite(model):\n",
    "    rule = pattern.RewriteRule(\n",
    "        erf_gelu_pattern,  # Target Pattern\n",
    "        gelu,  # Replacement\n",
    "    )\n",
    "    model_with_rewrite_applied = onnxscript.rewriter.rewrite(\n",
    "        model,\n",
    "        pattern_rewrite_rules=[rule],\n",
    "    )\n",
    "    return model_with_rewrite_applied"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 利用 `commute` 参数进行模式匹配"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用不同模式创建规则集\n",
    "\n",
    "此方法需要创建两个单独的规则，并将它们打包成 `PatternRewriteRules` 的序列或 `RewriteRuleSet`。创建 `RewriteRuleSet` 是首选选项，但两者都可以使用。为了创建一个包含多个规则（例如 `rule1` 和 `rule2`）的 `RewriteRuleSet`：\n",
    "```python\n",
    "from onnxscript.rewriter import pattern\n",
    "rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了将此方法应用于上述示例，首先创建两个单独的目标模式，如下所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def erf_gelu_pattern(op, x):\n",
    "    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))\n",
    "\n",
    "def erf_gelu_pattern_2(op, x):\n",
    "    return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后，为每个目标模式创建两个单独的 `PatternRewriteRules`。将这些规则打包成一个 `RewriteRuleSet` 对象，并通过传递创建的 `RewriteRuleSet` 作为 `pattern_rewrite_rules` 参数来应用重写。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_rewrite_with_ruleset(model):\n",
    "    # Create multiple rules\n",
    "    rule1 = pattern.RewriteRule(\n",
    "        erf_gelu_pattern,  # Target Pattern\n",
    "        gelu,  # Replacement\n",
    "    )\n",
    "    rule2 = pattern.RewriteRule(\n",
    "        erf_gelu_pattern_2,  # Target Pattern\n",
    "        gelu,  # Replacement\n",
    "    )\n",
    "    # Create a Rewrite Rule Set with multiple rules.\n",
    "    rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])\n",
    "    # Apply rewrites\n",
    "    model_with_rewrite_applied = onnxscript.rewriter.rewrite(\n",
    "        model,\n",
    "        pattern_rewrite_rules=rewrite_rule_set,\n",
    "        # pattern_rewrite_rules=[rule1, rule2], # Alternative method of passing multiple rules\n",
    "    )\n",
    "    return model_with_rewrite_applied"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 在创建规则时使用 `commute` 参数\n",
    "\n",
    "为相似模式创建多个目标模式可能会很繁琐。为了避免这种情况，可以在创建 `RewriteRuleSet` 时利用 `commute` 参数。只需设置 `commute=True`，即可避免为因交换性而不同的模式创建多个目标模式。满足交换性属性的不同模式的多个规则会自动打包成 `RewriteRuleSet` 对象。然后通过传递创建的 `RewriteRuleSet` 作为 `pattern_rewrite_rules` 参数来应用重写。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_rewrite_with_commute(model):\n",
    "    rule = pattern.RewriteRule(\n",
    "        erf_gelu_pattern,  # Target Pattern\n",
    "        gelu,  # Replacement\n",
    "    )\n",
    "    # Create a Rewrite Rule Set with commute=True\n",
    "    rewrite_rule_set = pattern.RewriteRuleSet([rule], commute=True)\n",
    "    # Apply rewrites\n",
    "    model_with_rewrite_applied = onnxscript.rewriter.rewrite(\n",
    "        model,\n",
    "        pattern_rewrite_rules=rewrite_rule_set,\n",
    "    )\n",
    "    return model_with_rewrite_applied"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  使用 `match_condition` 参数进行模式匹配\n",
    "\n",
    "本节将讨论如何利用 `match_condition` 参数。`match_condition` 参数检查模式是否在考虑某些约束的情况下与目标模式匹配。\n",
    "\n",
    "基于 [ONNX Matmul 规范](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul)，onnx Matmul的行为类似于 `numpy.matmul`，并且也遵循 `numpy` 广播。因此，在这个特定模式中，如果 `matmul` 广播满足，那么我们不需要 `reshapes`。为了验证这一点，我们需要检查以下内容：\n",
    "\n",
    "- 输入形状检查：`input_a` 和 `input_b` 应该是可广播的\n",
    "- 输出形状检查：`shape_c` 应该与从 `matmul(input_a, input_b)` 得到的输出形状相同"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果上述情况为真，那么我们不需要 `reshapes`，可以使用基于模式的重写来消除它们。\n",
    "\n",
    "首先，以类似于第一个示例的方式编写目标模式和替换模式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c):\n",
    "    reshape_a = op.Reshape(input_a, shape_a)\n",
    "    reshape_b = op.Reshape(input_b, shape_b)\n",
    "    matmul = op.MatMul(reshape_a, reshape_b)\n",
    "    return op.Reshape(matmul, shape_c)\n",
    "\n",
    "def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):\n",
    "    return op.MatMul(input_a, input_b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "在这种情况下，目标模式有5个输入：`input_a`、`input_b`、`shape_a`、`shape_b`、`shape_c`。然而，替换模式仅利用了 `input_a` 和 `input_b`。为了避免在替换模式签名中引用所有未使用的参数，只传递 `input_a` 和 `input_b`，并使用 `**_` 来表示所有未使用的参数。\n",
    "\n",
    "同样，在编写条件检查函数时，我们只需要 `input_a`、`input_b` 和 `shape_c`。在条件匹配函数签名中使用 `**_` 来表示所有未使用的参数。\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了验证 `matmul` 广播是否满足，我们编写一个条件检查函数，如下所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_if_not_need_reshape(\n",
    "    context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_\n",
    ") -> bool:\n",
    "    \"\"\"Condition to check if we need to replace the pattern.\n",
    "\n",
    "    If matmul broadcasting is enough, then we don't need the reshapes.\n",
    "\n",
    "    To validate this, we need to check the following:\n",
    "    1. Input shapes check: input_a and input_b should be broadcastable\n",
    "    2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)\n",
    "\n",
    "    If the above are true, then we don't need the reshapes.\n",
    "\n",
    "    Returns:\n",
    "        True if we need to replace the pattern, False otherwise.\n",
    "    \"\"\"\n",
    "    del context  # Reserved for future extensions\n",
    "\n",
    "    input_a_shape = input_a.shape\n",
    "    input_b_shape = input_b.shape\n",
    "    # TODO: Get a helper func to get const_value\n",
    "    _ir_utils.propagate_const_value(shape_c)\n",
    "    shape_c_tensor = shape_c.const_value\n",
    "    if shape_c_tensor is None:\n",
    "        logger.info(\"The value 'shape_c' is not statically known.\")\n",
    "        return False\n",
    "\n",
    "    if len(shape_c_tensor.shape) != 1:\n",
    "        logger.info(\n",
    "            \"Unexpected final shape. The shape of 'shape' value is %s\",\n",
    "            shape_c_tensor.shape,\n",
    "        )\n",
    "        return False\n",
    "\n",
    "    # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape\n",
    "    # information. So, we need to check if the shape is None and return False.\n",
    "    if input_a_shape is None or input_b_shape is None:\n",
    "        logger.info(\"Shape information is not available for the inputs and outputs.\")\n",
    "        return False\n",
    "    input_a_shape = input_a_shape.numpy()\n",
    "    input_b_shape = input_b_shape.numpy()\n",
    "    shape_c = shape_c_tensor.numpy().tolist()\n",
    "\n",
    "    a_rank = len(input_a_shape)\n",
    "    b_rank = len(input_b_shape)\n",
    "\n",
    "    # TODO(justinchuby): Check shape size\n",
    "\n",
    "    # 1. Check if input shapes are broadcastable\n",
    "    # 1.a. If the first input is 1-D, check whether\n",
    "    # the dim matches the last second dim of the second input.\n",
    "    mimic_matmul_broadcast_behavior = False\n",
    "    if a_rank < 2:\n",
    "        if b_rank < 2:\n",
    "            logger.info(\"Optimization of dot product is not supported yet.\")\n",
    "            return False\n",
    "        if input_a_shape[-1] != input_b_shape[-2]:\n",
    "            logger.info(\"Original shape is not MatMul compatible.\")\n",
    "            return False\n",
    "        else:\n",
    "            input_a_shape = [1, *input_a_shape]\n",
    "            a_rank = len(input_a_shape)\n",
    "            mimic_matmul_broadcast_behavior = True\n",
    "    # 1.b. If the second input is 1-D, check whether\n",
    "    # the dim matches the last dim of the first input.\n",
    "    if b_rank < 2:\n",
    "        if input_b_shape[-1] != input_a_shape[-1]:\n",
    "            logger.info(\"Original shape is not MatMul compatible.\")\n",
    "            return False\n",
    "        else:\n",
    "            input_b_shape = [*input_b_shape, 1]\n",
    "            b_rank = len(input_b_shape)\n",
    "            mimic_matmul_broadcast_behavior = True\n",
    "    # 1.c. If both inputs are at least 2-D, check whether\n",
    "    # the last dimension of the first input matches the second\n",
    "    # last dimension of the second input, and shape[:-2] are\n",
    "    # broadcastable.\n",
    "    input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]]\n",
    "    input_b_shape_except_last_dim = input_b_shape[:-1]\n",
    "    broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]\n",
    "    for idx, (dim_from_a, dim_from_b) in enumerate(\n",
    "        zip(\n",
    "            reversed(input_a_shape_except_second_last_dim),\n",
    "            reversed(input_b_shape_except_last_dim),\n",
    "        )\n",
    "    ):\n",
    "        if dim_from_a not in {1, dim_from_b}:\n",
    "            logger.info(\"Original shape is not broadcastable.\")\n",
    "            return False\n",
    "        elif idx > 0:\n",
    "            broadcast_matmul_output_shape = [\n",
    "                max(dim_from_a, dim_from_b),\n",
    "                *broadcast_matmul_output_shape,\n",
    "            ]\n",
    "\n",
    "    # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)\n",
    "    # Prepend the broadcast_matmul_output_shape with the longer shape of input\n",
    "    if a_rank > b_rank:\n",
    "        longer_shape = input_a_shape\n",
    "        shorter_shape = input_b_shape\n",
    "    else:\n",
    "        longer_shape = input_b_shape\n",
    "        shorter_shape = input_a_shape\n",
    "    broadcast_matmul_output_shape = [\n",
    "        *longer_shape[: -len(shorter_shape)],\n",
    "        *broadcast_matmul_output_shape,\n",
    "    ]\n",
    "    if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1:\n",
    "        # If input_b is expanded to 2-D, then we need to remove the last dimension\n",
    "        broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]\n",
    "    if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1:\n",
    "        # If input_a is expanded to 2-D, then we need to remove the first dimension\n",
    "        # of input_a, which would be the -2nd dimension of the output shape.\n",
    "        broadcast_matmul_output_shape.pop(-2)\n",
    "    if shape_c != broadcast_matmul_output_shape:\n",
    "        logger.info(\n",
    "            \"Final output shape is not the same. Expected %s vs actual %s\",\n",
    "            shape_c,\n",
    "            broadcast_matmul_output_shape,\n",
    "        )\n",
    "        return False\n",
    "\n",
    "    return True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "所有必要的组件都准备好了，现在创建带有 `match_condition` 函数的模式重写规则，然后调用 `rewriter.rewrite` 来应用重写。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_rewrite(model):\n",
    "    # Create rewrite rules\n",
    "    two_reshapes_matmul_reshape_rule = pattern.RewriteRule(\n",
    "        two_reshapes_matmul_reshape_pattern,  # target pattern\n",
    "        matmul_pattern,  # replacement pattern\n",
    "        check_if_not_need_reshape,  # match_condition function\n",
    "    )\n",
    "    # Create a Rewrite Rule Set\n",
    "    rewrite_rule_set = pattern.RewriteRuleSet([two_reshapes_matmul_reshape_rule])\n",
    "    # Apply rewrite while passing match_condition\n",
    "    model_with_rewrite = onnxscript.rewriter.rewrite(\n",
    "        model,\n",
    "        pattern_rewrite_rules=rewrite_rule_set,\n",
    "    )\n",
    "    return model_with_rewrite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xxx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
